import numpy


class ScoreParam:
    """Scoring scheme used by the affine-gap sequence-to-graph aligner.

    Attributes:
        match: score for aligning two equal characters.
        mismatch: score for aligning two different characters.
        gap_open: penalty charged once when a gap is started.
        gap_extend: penalty charged for every position a gap covers.
    """

    def __init__(self, match, mismatch, gap_open, gap_extend):
        # Substitution scores.
        self.match, self.mismatch = match, mismatch
        # Affine gap penalties.
        self.gap_open, self.gap_extend = gap_open, gap_extend

    def __str__(self):
        return (
            f"Match: {self.match}, Mismatch: {self.mismatch}, "
            f"Gap Open: {self.gap_open}, Gap Extend: {self.gap_extend}"
        )


class SeqGraphAlignment(object):
    __default_score = ScoreParam(1, -3, -2, -1)

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        score_params=__default_score,
        *args,
        **kwargs,
    ):
        self.score = score_params
        self.sequence = sequence
        self.graph = graph
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        return (
            "".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
            "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
        )

    def matchscore(self, c1, c2):
        if c1 == c2:
            return self.score.match
        else:
            return self.score.mismatch

    def matchscoreVec(self, c, v):
        return numpy.where(v == c, self.score.match, self.score.mismatch)

    def prevIndices(self, node, nodeIDtoIndex):
        prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
        if not prev:
            prev = [-1]
        return prev

    def initializeDynamicProgrammingData(self):
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        scores = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        if self.globalAlign:
            # M[0, i] = -inf
            scores[0, 0, :] = [
                -1000000000 for i in range(l2+1)
            ]
            scores[0, 0, 0] = 0
            # X[0, i] = gap_open + i * gap_extend
            scores[1, 0, :] = [
                self.score.gap_open + i * self.score.gap_extend for i in range(l2 + 1) 
            ]
            scores[1, 0, 0] = -1000000000
            # Y[0, i] = -inf
            scores[2, 0, :] = [
                -1000000000 for i in range(l2+1)
            ]

            ni = self.graph.nodeiterator()
            # After topology sort, the predcessors will have index less than the current node
            for index, node in enumerate(ni()):
                scores[0, index + 1, 0] = -1000000000
                scores[1, index + 1, 0] = -1000000000
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                best = scores[2 ,prevIdxs[0] + 1, 0]
                for prevIdx in prevIdxs:
                    best = max(best, scores[2, prevIdx + 1, 0])
                # If we have no predecessors, we start the gap 
                if prevIdxs == [-1]:
                    scores[2, index + 1, 0] =  self.score.gap_open + self.score.gap_extend
                else:
                    scores[2, index + 1, 0] = best + self.score.gap_extend

        # 3D Backtracking
        backStrIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backGrphIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backMtxIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx ,nodeIndexToID):
        besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
        #Storing best matrices for each [i,j]
        scores_arr = numpy.array(scores)
        max_m = numpy.argmax(scores_arr, axis=0)
        
        if self.globalAlign:
            ni = self.graph.nodeiterator()
            # Finding the best node to start from
            terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
            print(terminalIndices)
            besti = terminalIndices[0] + 1
            bestscore = scores[max_m[besti, bestj], besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[max_m[i + 1, bestj], i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1
            bestm = max_m[besti, bestj]

        matches = []
        strindexes = []

        while (besti != 0 or bestj != 0):
            nextm, nexti, nextj,  = backMtxIdx[bestm, besti, bestj], backGrphIdx[bestm, besti, bestj], backStrIdx[bestm, besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]
            
            if bestm == 0:
                matches.insert(0, curnodeidx)
                strindexes.insert(0, curstridx)
            elif bestm == 1:
                matches.insert(0, None)
                strindexes.insert(0, curstridx)
            else:
                matches.insert(0, curnodeidx)
                strindexes.insert(0, None)

            bestm, besti, bestj = nextm, nexti, nextj

        return strindexes, matches
